Title
INTERPRETING MACHINE LEARNING MODEL PREDICTIONS
AI has been deployed in the real world for a while now, and it has achieved stunning advances in varied fields, breaking barriers no one thought it would.
But taking a closer look, we realize that while the era of AI is driven by building and using models, not a lot of progress has been made on understanding their predictions. You know how to make a model work, but you don't know why it works the way it does.
A model can predict the price of a stock to an impressive degree of accuracy, but it never makes it clear why it predicted $15, not $16 or $14. The decision process behind a model is essentially opaque, especially in modern data science.
Two major reasons for this blind spot have been:
- The massive growth in the use of Artificial Neural Networks (ANNs) and their associated machine learning techniques (also known as 'Deep Learning'), which are incredible algorithms that achieve fantastic results. Deep Learning and Neural Networks leverage the recent rise in GPU access to train powerful and complex models (essentially huge matrices highly tuned using massive amounts of data). These are the models powering AI voice assistants, translators, facial recognition applications and many other cutting-edge technologies.
The parameters in an ANN, which it uses to make its predictions, are not set by the designers themselves, but are learnt by the network from the data. Due to their abstract nature, ANNs are essentially black-box machines, and while they make great predictions in their narrow domain of learning, their internal reasoning remains a mystery to builders and users.
- The inability of most organizations to grasp the necessity of understanding how their models work, something that never comes up until a model starts behaving in ways that run counter to the organization's goals.
But as data scientist and author Cathy O'Neil said, "models are opinions embedded in math".
So models carry with them things that the data scientists didn't intend: biases read from the data, unknown biases of the data scientist, and so on. If a professor in a college course tends to grade men higher than women, knowingly or unknowingly, a model trained on those grades will learn that bias and predict men as being more likely to succeed in that course.
As models go rogue and cause massive losses to companies for inexplicable reasons, or make unethical judgements/predictions that disproportionately hurt disadvantaged people, the need to understand why a model makes the predictions it does is rising rapidly across the world. This has led researchers and problem solvers everywhere to come up with techniques that allow the builders and users of AI to make more sense of what is going on inside a model that makes it behave the way it does. Together, these ideas form the area of model explainability, the field that tries to understand how the parameters in a neural network impact its eventual prediction.
In this and the next few blog posts, we will look into some of these techniques. Understanding this post just requires a basic idea of how AI predicts things. The posts following this one will get progressively more technical, for those who are interested in getting into the weeds.
Ok, let's get to it.
As an aside, there is a related concept in data science called model interpretability. Interpretability has to do with how accurately a machine learning model can associate a cause with an effect. Explainability has to do with the ability of the parameters, often hidden in deep nets, to justify the results. A comprehensive dive into these concepts, their contrasts and overlaps, can be found here
Not all models are created equal.
Some models, like Decision Trees and Linear Regression, are easily interpretable, and the decisions the models make when they see a piece of data can be understood with little effort. These are called glassbox models: models that we can look into to understand what's happening inside.
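To make the glassbox idea concrete, here is a minimal sketch (the data and names are made up for illustration): a linear regression fit on synthetic data whose learned coefficients directly reveal how each feature moves the prediction.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))                              # two made-up features
y = 3.0 * X[:, 0] - 2.0 * X[:, 1] + rng.normal(scale=0.1, size=200)

glassbox = LinearRegression().fit(X, y)
# Every prediction is just a transparent weighted sum of the inputs,
# so the learned coefficients *are* the explanation.
print(glassbox.coef_)   # close to the generating weights [3., -2.]
```

No extra machinery is needed: reading off `coef_` tells us exactly why any individual prediction came out the way it did.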
That isn't the case with Neural Networks, though. These are called blackbox models, since we can't really see what's going on inside them. Clever methods have been introduced to work around this, however. Let's start with the first one.
Model Agnostic methods
Surrogate models
A Surrogate model is an interpretable model that is trained to approximate the predictions of a black box model. We can draw conclusions about the black box model by interpreting the surrogate model.
For example, suppose we train a linear regression or random forest model to act as a surrogate for a neural network that predicts housing prices. The target for the surrogate isn't the actual housing prices; instead, it is the predicted output of the neural network. So the surrogate model's job is not to be good on real-world data, only to be good at mimicking/replicating the output of the neural network. (Consequently, it is the neural network's job to actually work in the real world.)
We link the output of the surrogate model and the blackbox model via some statistical measure like R^2.
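The recipe above can be sketched in a few lines. This is an illustrative toy, not the housing model from later in the post: a gradient-boosted ensemble stands in for the black box, and a shallow decision tree is the surrogate trained on the black box's predictions, scored with R^2.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import GradientBoostingRegressor   # stand-in "black box"
from sklearn.tree import DecisionTreeRegressor           # interpretable surrogate
from sklearn.metrics import r2_score

X, y = make_regression(n_samples=500, n_features=5, noise=10.0, random_state=0)

black_box = GradientBoostingRegressor(random_state=0).fit(X, y)
bb_preds = black_box.predict(X)   # the surrogate's target is these predictions,
                                  # never the real labels y

surrogate = DecisionTreeRegressor(max_depth=4, random_state=0).fit(X, bb_preds)

# How faithfully does the surrogate mimic the black box?
fidelity = r2_score(bb_preds, surrogate.predict(X))
print(f"surrogate fidelity R^2: {fidelity:.2f}")
```

Note that `fidelity` measures agreement with the black box, not accuracy on the real labels; a high value only tells us the tree is a faithful mimic.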
Some observations.
- This is an intuitive and simple workaround.
- You have to be aware that you draw conclusions about the model and not about the data, since the surrogate model never sees the real outcome.
- It is not clear what the best cut-off for R-squared is in order to be confident that the surrogate model is close enough to the black box model. Is 95% good enough? For the stock market? For disease-spread modeling? These are vague choices with no good answers.
- The whole approach is unsatisfactory in a way, since you still have no idea how the black box model is actually making its predictions; you have to assume (rightly?) that it is somehow following the same decision patterns as the surrogate model.
Nevertheless, we persevere.
While training the surrogate model, if we weight the data locally around a specific instance (the closer an instance is to the selected instance of interest, the higher its weight), we get a local surrogate model that can explain that instance's individual prediction. This is the concept behind LIME.
LIME?
LIME : Local interpretable model-agnostic explanation
Instead of training a global surrogate model, LIME focuses on training local surrogate models to explain individual predictions. It works something like this.
- Take the input features and the prediction produced by a blackbox model you want to understand
- Perturb the datapoint (create a bunch of datapoints from the original one by varying its features a little bit) to get a new mini dataset.
- Weight each new datapoint by how close it is to the original one, by some statistical measure
- Train the surrogate model on this new set of weighted data
- Interpret the surrogate model
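The steps above can be hand-rolled in a few lines. This is a simplified sketch of the idea, not the `lime` library's implementation, and all the names (the toy `black_box` function, the kernel width, and so on) are ours:

```python
import numpy as np
from sklearn.linear_model import Ridge

rng = np.random.default_rng(0)

def black_box(X):
    # Stand-in opaque model: nonlinear in feature 0, linear in feature 1.
    return np.sin(X[:, 0]) + 2.0 * X[:, 1]

x0 = np.array([0.5, 1.0])                        # the instance to explain

# 1. Perturb the instance into a mini dataset
Z = x0 + rng.normal(scale=0.3, size=(1000, 2))
# 2. Weight each perturbation by closeness to x0 (Gaussian kernel)
dist = np.linalg.norm(Z - x0, axis=1)
weights = np.exp(-(dist ** 2) / (2 * 0.5 ** 2))
# 3. Train an interpretable weighted model on the black box's outputs
local = Ridge(alpha=1e-3).fit(Z, black_box(Z), sample_weight=weights)
# 4. Interpret: near x0 the slope in feature 0 is about cos(0.5) ~ 0.88,
#    and the coefficient for feature 1 is about 2.
print(local.coef_)
```

The surrogate's coefficients are the explanation: they describe how the black box responds to each feature *near this one instance*, even though globally the function is nonlinear.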
OK. Let's check out some code. We're going to use a pre-trained neural network that predicts housing prices on the infamous Boston Housing Dataset.
import warnings
warnings.filterwarnings('ignore')
import pandas as pd
df = pd.read_csv('data/boston.csv') #read the data in, have a look
df.head()
A brief overview of what each feature means can be found here
Now let's pick a row of data, pass it through the model to get a prediction, and then pass it through the open-source LIME explainer, which will use a surrogate linear model to explain how each feature affected the prediction.
example = df.iloc[122,:-1]
example
Load up the pre-trained Keras Model
from tensorflow.keras.models import load_model
model = load_model('models/boston_keras_reg.tf')
prediction = model.predict(example.values.reshape(1, -1)) #prediction
prediction
from lime.lime_tabular import LimeTabularExplainer
explainer = LimeTabularExplainer(df.drop('target',axis=1).values,
mode="regression",
feature_names=df.columns.tolist()[:-1],
discretize_continuous=False)
The parameters passed to the explainer are:
- our training set; we need to make sure we use the training set without one-hot encoding
- mode: the explainer can be used for classification or regression
- feature_names: list of labels for our features
- categorical_features: list of indexes of categorical features
- categorical_names: dict mapping each index of a categorical feature to a list of corresponding labels
- discretize_continuous: will discretize numerical values into buckets that can be used for explanation. For instance, it can tell us that the decision was made because distance is in the bucket [5km, 10km], instead of just telling us that distance is an important feature.
explanation = explainer.explain_instance(example.values, #LIME expects a 1-d array, not a Series
                                         lambda x: model.predict(x).flatten()) #and a predict function returning a 1-d array
explanation.show_in_notebook()